{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "OpenSpiel - DeepCFR",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6fs6SIL3QgE-",
        "colab_type": "text"
      },
      "source": [
        "# OpenSpiel: Deep CFR in PyTorch\n",
        "In this notebook we show how to use OpenSpiel with PyTorch as the learning library. This works because most OpenSpiel components are agnostic to the learning library used.\n",
        "\n",
        "We implement the Deep Counterfactual Regret Minimization (Deep CFR) algorithm in PyTorch and apply it to Kuhn poker. The code is adapted from [this example](https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/examples/deep_cfr.py).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_4pUODpi0Lnl",
        "colab_type": "text"
      },
      "source": [
        "## Build and Install OpenSpiel\n",
        "Clone the latest version of OpenSpiel from GitHub."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2e9hMgKSg3e-",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title\n",
        "%cd /usr/local\n",
        "!git clone https://github.com/deepmind/open_spiel;\n",
        "%cd open_spiel"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3dvGiEYFChct",
        "colab_type": "text"
      },
      "source": [
        "Install the requirements and build the library."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "x2iMd1KbT5gO",
        "colab": {},
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "!./install.sh\n",
        "!pip install -r requirements.txt;\n",
        "!mkdir -p /usr/local/open_spiel/build\n",
        "%cd /usr/local/open_spiel/build\n",
        "!CXX=g++ cmake -DPython_TARGET_VERSION=3.6 -DCMAKE_CXX_COMPILER=g++ ../open_spiel;\n",
        "!make -j$(nproc);"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UqqrOKZGCypH",
        "colab_type": "text"
      },
      "source": [
        "Add the OpenSpiel directories to `sys.path` so that the Python bindings can be imported in the current session."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KIFkxD2NhK7J",
        "colab_type": "code",
        "colab": {},
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "%cd /content\n",
        "import sys\n",
        "sys.path.append('/usr/local/open_spiel')\n",
        "sys.path.append('/usr/local/open_spiel/build/python')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ss3qlgexEhV9",
        "colab_type": "text"
      },
      "source": [
        "## Import dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TJaIzvZLlelO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from open_spiel.python import policy\n",
        "from open_spiel.python.algorithms import exploitability\n",
        "from open_spiel.python.algorithms.deep_cfr import AdvantageMemory, StrategyMemory, ReservoirBuffer\n",
        "import pyspiel\n",
        "from absl import logging\n",
        "import collections\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn"
      ],
      "execution_count": 0,
      "outputs": []
    },
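    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`ReservoirBuffer` is imported from OpenSpiel rather than reimplemented here. The following is a rough sketch of the idea, assuming the buffer follows standard reservoir sampling (Algorithm R). A buffer of capacity `k` stores the first `k` items. After that, the `n`-th item overwrites a uniformly chosen slot with probability `k / n`, so every item seen so far is retained with equal probability. The hypothetical `ToyReservoirBuffer` below uses only the Python standard library and is not OpenSpiel's implementation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import random\n",
        "\n",
        "\n",
        "class ToyReservoirBuffer:\n",
        "  # Hypothetical fixed-capacity buffer using reservoir sampling (Algorithm R).\n",
        "\n",
        "  def __init__(self, capacity, seed=0):\n",
        "    self._capacity = capacity\n",
        "    self._data = []\n",
        "    self._add_calls = 0\n",
        "    self._rng = random.Random(seed)\n",
        "\n",
        "  def add(self, element):\n",
        "    if len(self._data) < self._capacity:\n",
        "      # Buffer not yet full: keep every element.\n",
        "      self._data.append(element)\n",
        "    else:\n",
        "      # randint is inclusive, so idx < capacity with probability\n",
        "      # capacity / (add_calls + 1); the chosen slot is overwritten.\n",
        "      idx = self._rng.randint(0, self._add_calls)\n",
        "      if idx < self._capacity:\n",
        "        self._data[idx] = element\n",
        "    self._add_calls += 1\n",
        "\n",
        "  def sample(self, num_samples):\n",
        "    # Uniform sample without replacement from the stored elements.\n",
        "    return self._rng.sample(self._data, num_samples)\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self._data)\n",
        "\n",
        "\n",
        "buf = ToyReservoirBuffer(capacity=100)\n",
        "for i in range(10000):\n",
        "  buf.add(i)\n",
        "print(len(buf))  # 100: the buffer never grows past its capacity\n",
        "print(buf.sample(5))"
      ],
      "execution_count": 0,
      "outputs": []
    },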
    {
      "cell_type": "code",
      "metadata": {
        "id": "qnFDt18N2ZjA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def reset_params(m):\n",
        "  if isinstance(m, nn.Linear):\n",
        "    m.reset_parameters()\n",
        "\n",
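        "# Fully-connected network used for both the policy and advantage networks.\n",
        "class MLP(nn.Module):\n",
        "  def __init__(self, layers):\n",
        "    \"\"\"Creates an MLP with layer sizes given by `layers`.\n",
        "\n",
        "    `layers` lists the input size, the hidden sizes and the output size.\n",
        "    A ReLU follows every linear layer except the last one.\n",
        "    \"\"\"\n",
        "    super(MLP, self).__init__()\n",
        "    modules = []\n",
        "    for i in range(len(layers) - 1):\n",
        "      modules.append(nn.Linear(layers[i], layers[i + 1]))\n",
        "      if i < len(layers) - 2:\n",
        "        modules.append(nn.ReLU())\n",
        "    self._model = nn.Sequential(*modules)\n",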
        "    \n",
        "  def forward(self, input):\n",
        "    return self._model(input)\n",
        "  \n",
        "  def reset(self):\n",
        "    self._model.apply(reset_params)"
      ],
      "execution_count": 0,
      "outputs": []
    },
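    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "This quick NumPy check shows why an MLP interleaves nonlinear activations such as ReLU between its `nn.Linear` layers. Two stacked affine maps collapse into one, since `(x W1 + b1) W2 + b2 = x (W1 W2) + (b1 W2 + b2)`. Without an activation, extra layers therefore add parameters but no expressive power. The random weights are for illustration only."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "x = rng.normal(size=(4, 3))  # batch of 4 inputs with 3 features\n",
        "W1, b1 = rng.normal(size=(3, 5)), rng.normal(size=5)\n",
        "W2, b2 = rng.normal(size=(5, 2)), rng.normal(size=2)\n",
        "\n",
        "# Two linear layers with no activation in between.\n",
        "two_layers = (x @ W1 + b1) @ W2 + b2\n",
        "# The equivalent single linear layer.\n",
        "W, b = W1 @ W2, b1 @ W2 + b2\n",
        "one_layer = x @ W + b\n",
        "print(np.allclose(two_layers, one_layer))  # True\n",
        "\n",
        "# A ReLU between the layers breaks this equivalence in general.\n",
        "with_relu = np.maximum(x @ W1 + b1, 0.0) @ W2 + b2\n",
        "print(np.allclose(with_relu, one_layer))"
      ],
      "execution_count": 0,
      "outputs": []
    },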
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P35t4vpCapIz",
        "colab_type": "text"
      },
      "source": [
        "## Implement DeepCFR\n",
        "We follow the structure of the [TensorFlow implementation](https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/algorithms/deep_cfr.py), rewriting the optimization steps in PyTorch."
      ]
    },
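    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Two small computations sit at the core of the game-tree traversal. *Regret matching* turns advantage estimates into a strategy. Each legal action gets probability proportional to its positive advantage, and the strategy is uniform over the legal actions when no advantage is positive. The *sampled regret* of an action is its payoff minus the counterfactual value, i.e. the strategy-weighted expected payoff at that information state. The NumPy sketch below illustrates both on made-up numbers. `regret_matching` and `sampled_regrets` are hypothetical helpers, not part of OpenSpiel."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def regret_matching(advantages, legal_actions, num_actions):\n",
        "  # Clip negative advantages to zero, then normalize over legal actions.\n",
        "  positive = np.maximum(np.asarray(advantages, dtype=float), 0.0)\n",
        "  strategy = np.zeros(num_actions)\n",
        "  total = positive[legal_actions].sum()\n",
        "  if total > 0.0:\n",
        "    strategy[legal_actions] = positive[legal_actions] / total\n",
        "  else:\n",
        "    # No positive advantage: play uniformly over the legal actions.\n",
        "    strategy[legal_actions] = 1.0 / len(legal_actions)\n",
        "  return strategy\n",
        "\n",
        "\n",
        "def sampled_regrets(payoffs, strategy, legal_actions, num_actions):\n",
        "  # Counterfactual value: expected payoff under the current strategy.\n",
        "  cfv = sum(strategy[a] * payoffs[a] for a in legal_actions)\n",
        "  regrets = np.zeros(num_actions)\n",
        "  for a in legal_actions:\n",
        "    regrets[a] = payoffs[a] - cfv\n",
        "  return regrets, cfv\n",
        "\n",
        "\n",
        "strategy = regret_matching([3.0, -2.0, 1.0], legal_actions=[0, 1, 2], num_actions=3)\n",
        "print(strategy)  # probabilities 0.75, 0.0, 0.25\n",
        "regrets, cfv = sampled_regrets({0: 1.0, 1: -1.0, 2: 2.0}, strategy, [0, 1, 2], 3)\n",
        "print(cfv)  # 1.25\n",
        "print(regrets)  # regrets -0.25, -2.25, 0.75"
      ],
      "execution_count": 0,
      "outputs": []
    },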
    {
      "cell_type": "code",
      "metadata": {
        "id": "TMXvG7O80u50",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class DeepCFR(policy.Policy):\n",
        "  \"\"\"Implements a solver for the Deep CFR Algorithm with PyTorch.\n",
        "\n",
        "  See https://arxiv.org/abs/1811.00164.\n",
        "\n",
        "  Define all networks and sampling buffers/memories.  Derive losses & learning\n",
        "  steps. Initialize the game state and algorithmic variables.\n",
        "\n",
        "  Note: batch sizes default to `None` implying that training over the full\n",
        "        dataset in memory is done by default.  To sample from the memories you\n",
        "        may set these values to something less than the full capacity of the\n",
        "        memory.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self,\n",
        "               game,\n",
        "               policy_network_layers=(256, 256),\n",
        "               advantage_network_layers=(128, 128),\n",
        "               num_iterations: int = 100,\n",
        "               num_traversals: int = 20,\n",
        "               learning_rate: float = 1e-4,\n",
        "               batch_size_advantage=None,\n",
        "               batch_size_strategy=None,\n",
        "               memory_capacity: int = int(1e6),\n",
        "               policy_network_train_steps: int = 1,\n",
        "               advantage_network_train_steps: int = 1,\n",
        "               reinitialize_advantage_networks: bool = True):\n",
        "    \"\"\"Initialize the Deep CFR algorithm.\n",
        "\n",
        "    Args:\n",
        "      game: Open Spiel game.\n",
        "      policy_network_layers: (list[int]) Layer sizes of strategy net MLP.\n",
        "      advantage_network_layers: (list[int]) Layer sizes of advantage net MLP.\n",
        "      num_iterations: (int) Number of training iterations.\n",
        "      num_traversals: (int) Number of traversals per iteration.\n",
        "      learning_rate: (float) Learning rate.\n",
        "      batch_size_advantage: (int or None) Batch size to sample from advantage\n",
        "        memories.\n",
        "      batch_size_strategy: (int or None) Batch size to sample from strategy\n",
        "        memories.\n",
        "      memory_capacity: Number of samples that can be stored in memory.\n",
        "      policy_network_train_steps: Number of policy network training steps (per\n",
        "        iteration).\n",
        "      advantage_network_train_steps: Number of advantage network training steps\n",
        "        (per iteration).\n",
        "      reinitialize_advantage_networks: Whether to re-initialize the\n",
        "        advantage network before training on each iteration.\n",
        "    \"\"\"\n",
        "    all_players = list(range(game.num_players()))\n",
        "    super(DeepCFR, self).__init__(game, all_players)\n",
        "    self._game = game\n",
        "    if game.get_type().dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:\n",
        "      # `_traverse_game_tree` does not take into account this option.\n",
        "      raise ValueError(\"Simultaneous games are not supported.\")\n",
        "    self._batch_size_advantage = batch_size_advantage\n",
        "    self._batch_size_strategy = batch_size_strategy\n",
        "    self._policy_network_train_steps = policy_network_train_steps\n",
        "    self._advantage_network_train_steps = advantage_network_train_steps\n",
        "    self._num_players = game.num_players()\n",
        "    self._root_node = self._game.new_initial_state()\n",
        "    self._embedding_size = len(\n",
        "        self._root_node.information_state_tensor(0))\n",
        "    self._num_iterations = num_iterations\n",
        "    self._num_traversals = num_traversals\n",
        "    self._reinitialize_advantage_networks = reinitialize_advantage_networks\n",
        "    self._num_actions = game.num_distinct_actions()\n",
        "    self._iteration = 1\n",
        "\n",
        "    # Define strategy network, loss & memory.\n",
        "    self._strategy_memories = ReservoirBuffer(memory_capacity)\n",
        "    self._policy_network = MLP(\n",
        "      [self._embedding_size] + list(policy_network_layers) + [self._num_actions])\n",
        "    # Illegal actions are handled in the traversal code where expected payoff\n",
        "    # and sampled regret is computed from the advantage networks.\n",
        "    self._policy_sm = nn.Softmax(dim=-1)\n",
        "    self._loss_policy = nn.MSELoss()\n",
        "    self._optimizer_policy = torch.optim.Adam(\n",
        "                  self._policy_network.parameters(), lr=learning_rate)\n",
        "\n",
        "    # Define advantage network, loss & memory. (One per player)\n",
        "    self._advantage_memories = [\n",
        "        ReservoirBuffer(memory_capacity) for _ in range(self._num_players)\n",
        "    ]\n",
        "    self._advantage_networks = [\n",
        "        MLP([self._embedding_size] + list(advantage_network_layers) + \n",
        "            [self._num_actions])\n",
        "        for _ in range(self._num_players)\n",
        "    ]\n",
        "    self._loss_advantages = nn.MSELoss(reduction='mean')\n",
        "    self._optimizer_advantages = []\n",
        "    for p in range(self._num_players):\n",
        "      self._optimizer_advantages.append(torch.optim.Adam(\n",
        "            self._advantage_networks[p].parameters(), lr=learning_rate))\n",
        "\n",
        "  @property\n",
        "  def advantage_buffers(self):\n",
        "    return self._advantage_memories\n",
        "\n",
        "  @property\n",
        "  def strategy_buffer(self):\n",
        "    return self._strategy_memories\n",
        "\n",
        "  def clear_advantage_buffers(self):\n",
        "    for p in range(self._num_players):\n",
        "      self._advantage_memories[p].clear()\n",
        "\n",
        "  def reinitialize_advantage_network(self, player):\n",
        "    self._advantage_networks[player].reset()\n",
        "\n",
        "  def reinitialize_advantage_networks(self):\n",
        "    for p in range(self._num_players):\n",
        "      self.reinitialize_advantage_network(p)\n",
        "\n",
        "  def solve(self):\n",
        "    \"\"\"Solution logic for Deep CFR.\n",
        "\n",
        "    Traverses the game tree, storing the transitions used to train the\n",
        "    advantage and policy networks.\n",
        "\n",
        "    Returns:\n",
        "      1. (nn.Module) Instance of the trained policy network for inference.\n",
        "      2. (dict[int, list]) Advantage network losses for each player, one\n",
        "        entry per iteration.\n",
        "      3. (float) Policy loss.\n",
        "    \"\"\"\n",
        "    advantage_losses = collections.defaultdict(list)\n",
        "    for _ in range(self._num_iterations):\n",
        "      for p in range(self._num_players):\n",
        "        for _ in range(self._num_traversals):\n",
        "          self._traverse_game_tree(self._root_node, p)\n",
        "        if self._reinitialize_advantage_networks:\n",
        "          # Re-initialize advantage network for player and train from scratch.\n",
        "          self.reinitialize_advantage_network(p)\n",
        "        # Train the advantage network on this player's memory.\n",
        "        advantage_losses[p].append(self._learn_advantage_network(p))\n",
        "      self._iteration += 1\n",
        "    # Train the policy network once, on the accumulated strategy memory.\n",
        "    policy_loss = self._learn_strategy_network()\n",
        "    return self._policy_network, advantage_losses, policy_loss\n",
        "\n",
        "  def _traverse_game_tree(self, state, player):\n",
        "    \"\"\"Performs a traversal of the game tree.\n",
        "\n",
        "    Over a traversal the advantage and strategy memories are populated with\n",
        "    computed advantage values and matched regrets respectively.\n",
        "\n",
        "    Args:\n",
        "      state: Current OpenSpiel game state.\n",
        "      player: (int) Player index for this traversal.\n",
        "\n",
        "    Returns:\n",
        "      (float) Expected payoff for `player` from `state`.\n",
        "    \"\"\"\n",
        "    expected_payoff = collections.defaultdict(float)\n",
        "    if state.is_terminal():\n",
        "      # At a terminal state, return the traversing player's payoff.\n",
        "      return state.returns()[player]\n",
        "    elif state.is_chance_node():\n",
        "      # At a chance node, sample one outcome according to its probability.\n",
        "      outcomes, probs = zip(*state.chance_outcomes())\n",
        "      action = np.random.choice(outcomes, p=probs)\n",
        "      return self._traverse_game_tree(state.child(action), player)\n",
        "    elif state.current_player() == player:\n",
        "      sampled_regret = collections.defaultdict(float)\n",
        "      # Update the policy over the info set & actions via regret matching.\n",
        "      _, strategy = self._sample_action_from_advantage(state, player)\n",
        "      for action in state.legal_actions():\n",
        "        expected_payoff[action] = self._traverse_game_tree(\n",
        "            state.child(action), player)\n",
        "      # Counterfactual value: expected payoff under the current strategy.\n",
        "      cfv = 0.\n",
        "      for a_ in state.legal_actions():\n",
        "        cfv += strategy[a_] * expected_payoff[a_]\n",
        "      for action in state.legal_actions():\n",
        "        sampled_regret[action] = expected_payoff[action] - cfv\n",
        "      sampled_regret_arr = [0] * self._num_actions\n",
        "      for action in sampled_regret:\n",
        "        sampled_regret_arr[action] = sampled_regret[action]\n",
        "      self._advantage_memories[player].add(\n",
        "          AdvantageMemory(state.information_state_tensor(),\n",
        "                          self._iteration, sampled_regret_arr, action))\n",
        "      return cfv\n",
        "    else:\n",
        "      other_player = state.current_player()\n",
        "      _, strategy = self._sample_action_from_advantage(state, other_player)\n",
        "      # Renormalize to guard against numerical errors.\n",
        "      probs = np.array(strategy)\n",
        "      probs /= probs.sum()\n",
        "      sampled_action = np.random.choice(range(self._num_actions), p=probs)\n",
        "      self._strategy_memories.add(\n",
        "          StrategyMemory(\n",
        "              state.information_state_tensor(other_player),\n",
        "              self._iteration, strategy))\n",
        "      return self._traverse_game_tree(state.child(sampled_action), player)\n",
        "\n",
        "  def _sample_action_from_advantage(self, state, player):\n",
        "    \"\"\"Returns an info state policy by applying regret-matching.\n",
        "\n",
        "    Args:\n",
        "      state: Current OpenSpiel game state.\n",
        "      player: (int) Player index over which to compute regrets.\n",
        "\n",
        "    Returns:\n",
        "      1. (list) Advantage values for info state actions indexed by action.\n",
        "      2. (np.ndarray) Matched regrets, i.e. action probabilities indexed by action.\n",
        "    \"\"\"\n",
        "    info_state = state.information_state_tensor(player)\n",
        "    legal_actions = state.legal_actions(player)\n",
        "    with torch.no_grad():\n",
        "      state_tensor = torch.FloatTensor(np.expand_dims(info_state, axis=0))\n",
        "      advantages = self._advantage_networks[player](state_tensor)[0].numpy()\n",
        "    advantages = [max(0., advantage) for advantage in advantages]\n",
        "    cumulative_regret = np.sum(\n",
        "                            [advantages[action] for action in legal_actions])\n",
        "    matched_regrets = np.array([0.] * self._num_actions)\n",
        "    for action in legal_actions:\n",
        "      if cumulative_regret > 0.:\n",
        "        matched_regrets[action] = advantages[action] / cumulative_regret\n",
        "      else:\n",
        "        matched_regrets[action] = 1. / len(legal_actions)\n",
        "    return advantages, matched_regrets\n",
        "\n",
        "  def action_probabilities(self, state):\n",
        "    \"\"\"Computes action probabilities for the current player in `state`.\n",
        "\n",
        "    Args:\n",
        "      state: (pyspiel.State) The state to evaluate.\n",
        "\n",
        "    Returns:\n",
        "      (dict) Mapping from each legal action to its probability.\n",
        "    \"\"\"\n",
        "    cur_player = state.current_player()\n",
        "    legal_actions = state.legal_actions(cur_player)\n",
        "    info_state_vector = np.array(state.information_state_tensor())\n",
        "    if len(info_state_vector.shape) == 1:\n",
        "      info_state_vector = np.expand_dims(info_state_vector, axis=0)\n",
        "    with torch.no_grad():\n",
        "      logits = self._policy_network(torch.FloatTensor(info_state_vector))\n",
        "      probs = self._policy_sm(logits).numpy()\n",
        "    return {action: probs[0][action] for action in legal_actions}\n",
        "\n",
        "  def _learn_advantage_network(self, player):\n",
        "    \"\"\"Compute the loss on sampled transitions and update the advantage network.\n",
        "\n",
        "    If there are not enough elements in the buffer, no loss is computed and\n",
        "    `None` is returned instead.\n",
        "\n",
        "    Args:\n",
        "      player: (int) player index.\n",
        "\n",
        "    Returns:\n",
        "      (float) The loss from the last training step, or `None`.\n",
        "    \"\"\"\n",
        "    for _ in range(self._advantage_network_train_steps):\n",
        "\n",
        "      if self._batch_size_advantage:\n",
        "        if self._batch_size_advantage > len(self._advantage_memories[player]):\n",
        "          ## Skip if there aren't enough samples\n",
        "          return None\n",
        "        samples = self._advantage_memories[player].sample(\n",
        "            self._batch_size_advantage)\n",
        "      else:\n",
        "        samples = self._advantage_memories[player]\n",
        "      info_states = []\n",
        "      advantages = []\n",
        "      iterations = []\n",
        "      for s in samples:\n",
        "        info_states.append(s.info_state)\n",
        "        advantages.append(s.advantage)\n",
        "        iterations.append([s.iteration])\n",
        "      # Ensure some samples have been gathered.\n",
        "      if not info_states:\n",
        "        return None\n",
        "      self._optimizer_advantages[player].zero_grad()\n",
        "      advantages = torch.FloatTensor(np.array(advantages))\n",
        "      iters = torch.FloatTensor(np.sqrt(np.array(iterations)))\n",
        "      outputs = self._advantage_networks[player](\n",
        "                        torch.FloatTensor(np.array(info_states)))\n",
        "      loss_advantages = self._loss_advantages(iters * outputs, \n",
        "                                              iters * advantages)\n",
        "      loss_advantages.backward()\n",
        "      self._optimizer_advantages[player].step()\n",
        "\n",
        "    return loss_advantages.detach().numpy()\n",
        "\n",
        "  def _learn_strategy_network(self):\n",
        "    \"\"\"Compute the loss over the strategy network.\n",
        "\n",
        "    Returns:\n",
        "      (float) The average loss obtained on this batch of transitions or `None`.\n",
        "    \"\"\"\n",
        "    for _ in range(self._policy_network_train_steps):\n",
        "      if self._batch_size_strategy:\n",
        "        if self._batch_size_strategy > len(self._strategy_memories):\n",
        "          ## Skip if there aren't enough samples\n",
        "          return None\n",
        "        samples = self._strategy_memories.sample(self._batch_size_strategy)\n",
        "      else:\n",
        "        samples = self._strategy_memories\n",
        "      info_states = []\n",
        "      action_probs = []\n",
        "      iterations = []\n",
        "      for s in samples:\n",
        "        info_states.append(s.info_state)\n",
        "        action_probs.append(s.strategy_action_probs)\n",
        "        iterations.append([s.iteration])\n",
        "      # Ensure some samples have been gathered.\n",
        "      if not info_states:\n",
        "        return None\n",
        "\n",
        "      self._optimizer_policy.zero_grad()\n",
        "      iters = torch.FloatTensor(np.sqrt(np.array(iterations)))\n",
        "      ac_probs = torch.FloatTensor(np.array(np.squeeze(action_probs)))\n",
        "      logits = self._policy_network(torch.FloatTensor(np.array(info_states)))\n",
        "      outputs = self._policy_sm(logits)\n",
        "      loss_strategy = self._loss_policy(iters * outputs, iters * ac_probs)\n",
        "      loss_strategy.backward()\n",
        "      self._optimizer_policy.step()\n",
        "\n",
        "    return loss_strategy.detach().numpy()\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
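    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Both `_learn_advantage_network` and `_learn_strategy_network` multiply predictions and targets by `sqrt(t)`, where `t` is the iteration at which a sample was stored. Because `(sqrt(t) * y_hat - sqrt(t) * y)**2 == t * (y_hat - y)**2`, this is the same as an MSE in which each sample is weighted linearly by its iteration. That matches the iteration-proportional (linear) weighting described in the Deep CFR paper. The NumPy check below uses made-up values purely for illustration."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "num_samples, num_actions = 6, 3\n",
        "preds = rng.normal(size=(num_samples, num_actions))    # network outputs\n",
        "targets = rng.normal(size=(num_samples, num_actions))  # stored regrets or strategies\n",
        "iterations = rng.integers(1, 100, size=(num_samples, 1))  # iteration per sample\n",
        "\n",
        "# What the training code computes: MSE after scaling both sides by sqrt(t).\n",
        "iters = np.sqrt(iterations)\n",
        "scaled_mse = np.mean((iters * preds - iters * targets) ** 2)\n",
        "\n",
        "# Equivalent iteration-weighted MSE (still averaged over all N * num_actions entries).\n",
        "weighted_mse = np.mean(iterations * (preds - targets) ** 2)\n",
        "print(np.isclose(scaled_mse, weighted_mse))  # True"
      ],
      "execution_count": 0,
      "outputs": []
    },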
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F6g8XAmdgQuZ",
        "colab_type": "text"
      },
      "source": [
        "## Train the model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "46ZnC9qGq3bT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "game = pyspiel.load_game('kuhn_poker')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WqVsQUOqm9kp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "solver = DeepCFR(game,\n",
        "      policy_network_layers=(32, 32),\n",
        "      advantage_network_layers=(16, 16),\n",
        "      num_iterations=400,\n",
        "      num_traversals=40,\n",
        "      learning_rate=1e-3,\n",
        "      batch_size_advantage=None,\n",
        "      batch_size_strategy=None,\n",
        "      memory_capacity=int(1e7))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m-TnC_tuE5oP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "_, advantage_losses, policy_loss = solver.solve()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ljaGnchaLlfV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "for player, losses in advantage_losses.items():\n",
        "  print(\"Advantage for player:\", player,\n",
        "        losses[:2] + [\"...\"] + losses[-2:])\n",
        "  print(\"Advantage Buffer Size for player\", player,\n",
        "        len(solver.advantage_buffers[player]))\n",
        "print(\"Strategy Buffer Size:\", len(solver.strategy_buffer))\n",
        "print(\"Final policy loss:\", policy_loss)\n",
        "conv = exploitability.nash_conv(\n",
        "    game,\n",
        "    policy.tabular_policy_from_callable(game, solver.action_probabilities))\n",
        "print(\"Deep CFR - NashConv:\", conv)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
